{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "28S76DVlfCMZ"
      },
      "source": [
        "# Adversarial Reprogramming of MNIST Neural Cellular Automata\n",
        "\n",
        "This notebook contains code to reproduce experiments and figures regarding MNIST CAs for the \"Adversarial Reprogramming of Neural Cellular Automata\" article.\n",
        "\n",
        "*Copyright 2020 Google LLC*\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "i5wi_r4gyzFr"
      },
      "outputs": [],
      "source": [
        "#@title imports and notebook utils\n",
        "%tensorflow_version 2.x\n",
        "\n",
        "import os\n",
        "import io\n",
        "import PIL.Image, PIL.ImageDraw\n",
        "import base64\n",
        "import zipfile\n",
        "import json\n",
        "import requests\n",
        "import numpy as np\n",
        "import matplotlib.pylab as pl\n",
        "import matplotlib\n",
        "import glob\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from IPython.display import Image, HTML, clear_output\n",
        "import tqdm\n",
        "\n",
        "import os\n",
        "os.environ['FFMPEG_BINARY'] = 'ffmpeg'\n",
        "import moviepy.editor as mvp\n",
        "from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter\n",
        "clear_output()\n",
        "\n",
        "def np2pil(a):\n",
        "  # Convert a numpy image to a PIL image. Float images are clipped to\n",
        "  # [0, 1] and rescaled to uint8; other dtypes are passed through as-is.\n",
        "  is_float = a.dtype in [np.float32, np.float64]\n",
        "  pixels = np.uint8(np.clip(a, 0, 1)*255) if is_float else a\n",
        "  return PIL.Image.fromarray(pixels)\n",
        "\n",
        "def imwrite(f, a, fmt=None):\n",
        "  # Save image array `a` to `f`, which is either a path or a writable binary\n",
        "  # file object. For a path, the format is taken from the extension ('jpg'\n",
        "  # is mapped to 'jpeg') and the file is closed once written.\n",
        "  a = np.asarray(a)\n",
        "  if isinstance(f, str):\n",
        "    fmt = f.rsplit('.', 1)[-1].lower()\n",
        "    if fmt == 'jpg':\n",
        "      fmt = 'jpeg'\n",
        "    # Context manager so the file handle is not leaked.\n",
        "    with open(f, 'wb') as fh:\n",
        "      np2pil(a).save(fh, fmt, quality=95)\n",
        "  else:\n",
        "    np2pil(a).save(f, fmt, quality=95)\n",
        "\n",
        "def imencode(a, fmt='jpeg'):\n",
        "  # Encode image array `a` to bytes in format `fmt`. RGBA images are always\n",
        "  # encoded as PNG so that the alpha channel survives.\n",
        "  a = np.asarray(a)\n",
        "  has_alpha = len(a.shape) == 3 and a.shape[-1] == 4\n",
        "  buffer = io.BytesIO()\n",
        "  imwrite(buffer, a, 'png' if has_alpha else fmt)\n",
        "  return buffer.getvalue()\n",
        "\n",
        "def im2url(a, fmt='jpeg'):\n",
        "  # Return image `a` as a base64 data URL. imencode() always encodes RGBA\n",
        "  # images as PNG, so the same rule is applied here to keep the MIME type\n",
        "  # in the URL consistent with the encoded bytes.\n",
        "  a = np.asarray(a)\n",
        "  if len(a.shape) == 3 and a.shape[-1] == 4:\n",
        "    fmt = 'png'\n",
        "  encoded = imencode(a, fmt)\n",
        "  base64_byte_string = base64.b64encode(encoded).decode('ascii')\n",
        "  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string\n",
        "\n",
        "def imshow(a, fmt='jpeg'):\n",
        "  # Display image array `a` inline in the notebook output.\n",
        "  encoded = imencode(a, fmt)\n",
        "  display(Image(data=encoded))\n",
        "\n",
        "def tile2d(a, w=None):\n",
        "  a = np.asarray(a)\n",
        "  if w is None:\n",
        "    w = int(np.ceil(np.sqrt(len(a))))\n",
        "  th, tw = a.shape[1:3]\n",
        "  pad = (w-len(a))%w\n",
        "  a = np.pad(a, [(0, pad)]+[(0, 0)]*(a.ndim-1), 'constant')\n",
        "  h = len(a)//w\n",
        "  a = a.reshape([h, w]+list(a.shape[1:]))\n",
        "  a = np.rollaxis(a, 2, 1).reshape([th*h, tw*w]+list(a.shape[4:]))\n",
        "  return a\n",
        "\n",
        "def zoom(img, scale=4):\n",
        "  img = np.repeat(img, scale, 0)\n",
        "  img = np.repeat(img, scale, 1)\n",
        "  return img\n",
        "\n",
        "class VideoWriter:\n",
        "  def __init__(self, filename, fps=30.0, **kw):\n",
        "    self.writer = None\n",
        "    self.params = dict(filename=filename, fps=fps, **kw)\n",
        "\n",
        "  def add(self, img):\n",
        "    img = np.asarray(img)\n",
        "    if self.writer is None:\n",
        "      h, w = img.shape[:2]\n",
        "      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)\n",
        "    if img.dtype in [np.float32, np.float64]:\n",
        "      img = np.uint8(img.clip(0, 1)*255)\n",
        "    if len(img.shape) == 2:\n",
        "      img = np.repeat(img[..., None], 3, -1)\n",
        "    if len(img.shape) == 3 and img.shape[-1] == 4:\n",
        "      img = img[..., :3] * img[..., 3, None]\n",
        "    self.writer.write_frame(img)\n",
        "\n",
        "  def close(self):\n",
        "    if self.writer:\n",
        "      self.writer.close()\n",
        "\n",
        "  def __enter__(self):\n",
        "    return self\n",
        "\n",
        "  def __exit__(self, *kw):\n",
        "    self.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zmOPQeGLt9Uv"
      },
      "source": [
        "## Load MNIST"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Wo1AApPulqD"
      },
      "outputs": [],
      "source": [
        "# @title Generate train/test set from MNIST.\n",
        "\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "# Rescale pixel intensities from [0, 255] to float32 values in [0, 1].\n",
        "x_train, x_test = [\n",
        "    (images / 255.0).astype(np.float32) for images in (x_train, x_test)]\n",
        "\n",
        "# Data generator.\n",
        "# RGB palette for rendering per-pixel labels: rows 0-9 are the ten digit\n",
        "# classes, row 10 is black and row 11 is white (see color_labels).\n",
        "color_lookup = tf.constant([\n",
        "    [128, 0, 0],\n",
        "    [230, 25, 75],\n",
        "    [70, 240, 240],\n",
        "    [210, 245, 60],\n",
        "    [250, 190, 190],\n",
        "    [170, 110, 40],\n",
        "    [170, 255, 195],\n",
        "    [165, 163, 159],\n",
        "    [0, 128, 128],\n",
        "    [128, 128, 0],\n",
        "    [0, 0, 0],  # This is the default for digits.\n",
        "    [255, 255, 255],  # This is the background.\n",
        "])\n",
        "\n",
        "# True: black digits on a white background; False: the reverse.\n",
        "backgroundWhite = True\n",
        "def color_labels(x, y_pic, disable_black=False, dtype=tf.uint8):\n",
        "  # Render per-pixel class scores as an RGB image using `color_lookup`.\n",
        "  #   x: grayscale image(s); pixels with x \u003e 0.1 count as part of the digit.\n",
        "  #   y_pic: per-pixel class scores, presumably shaped x.shape + [10] (the 10\n",
        "  #     classes plus the 2 black/white entries index color_lookup's 12 rows).\n",
        "  #   disable_black: if True, digit pixels never fall back to black/white.\n",
        "  #   dtype: tf.uint8 returns values in 0-255; any other dtype returns the\n",
        "  #     colors cast to that dtype and divided by 255.\n",
        "  # Returns a tensor of shape x.shape + [3].\n",
        "  # works for shapes of x [b, r, c] and [r, c]\n",
        "  # Constant 0.01 scores for the extra black (index 10) and white (index 11)\n",
        "  # palette entries; after masking below they are picked by the argmax where\n",
        "  # the class scores are small.\n",
        "  black_and_white = tf.fill(list(x.shape) + [2], 0.01)\n",
        "  is_gray = tf.cast(x \u003e 0.1, tf.float32)\n",
        "  is_not_gray = 1. - is_gray\n",
        "\n",
        "  y_pic = y_pic * tf.expand_dims(is_gray, -1) # zero the class scores outside the digit.\n",
        "  \n",
        "  # if disable_black, make is_gray super low.\n",
        "  if disable_black:\n",
        "    is_gray *= -1e5\n",
        "    # this ensures that you don't draw white in the digits.\n",
        "    is_not_gray += is_gray\n",
        "\n",
        "  # With backgroundWhite, index 10 (black) is active on the digit and index 11\n",
        "  # (white) on the background; otherwise the two are swapped.\n",
        "  bnw_order = [is_gray, is_not_gray] if backgroundWhite else [is_not_gray, is_gray]\n",
        "  black_and_white *= tf.stack(bnw_order, -1)\n",
        "\n",
        "  rgb = tf.gather(\n",
        "      color_lookup,\n",
        "      tf.argmax(tf.concat([y_pic, black_and_white], -1), -1))\n",
        "  if dtype == tf.uint8:\n",
        "    return tf.cast(rgb, tf.uint8)\n",
        "  else:\n",
        "    return tf.cast(rgb, dtype) / 255.\n",
        "\n",
        "def to_ten_dim_label(x, y):\n",
        "  # x shape is [b, r, c]\n",
        "  # y shape is [b]\n",
        "\n",
        "  # y_res shape is [b, r, c, 10]\n",
        "  y_res = np.zeros(list(x.shape) + [10])\n",
        "  # broadcast y to match x shape:\n",
        "  y_expanded = np.broadcast_to(y, x.T.shape).T\n",
        "  y_res[x \u003e= 0.1, y_expanded[x \u003e= 0.1]] = 1.0\n",
        "  return y_res.astype(np.float32)\n",
        "\n",
        "# Hijack the target to be always 8\n",
        "def to_ten_dim_label_hijacked(x, fixed_y):\n",
        "  # x shape is [b, r, c]\n",
        "  # y shape is [b]\n",
        "\n",
        "  # y_res shape is [b, r, c, 10]\n",
        "  y_res = np.zeros(list(x.shape) + [10])\n",
        "  # broadcast y to match x shape:\n",
        "  y_expanded = np.broadcast_to(fixed_y, x.T.shape).T\n",
        "  y_res[x \u003e= 0.1, y_expanded[x \u003e= 0.1]] = 1.0\n",
        "  return y_res.astype(np.float32)\n",
        "\n",
        "\n",
        "def find_different_numbers(x_set, y_set, y_set_pic, orientation=\"vertical\"):\n",
        "  # For each digit 0-9, pick the first example in the set with that label.\n",
        "  # Returns (legend, examples): the examples' colored label images joined\n",
        "  # top-to-bottom (orientation \"vertical\") or left-to-right (any other\n",
        "  # value), and the raw inputs stacked into one array.\n",
        "  result_y = []\n",
        "  result_x = []\n",
        "  for digit in range(10):\n",
        "    first_match = next(\n",
        "        ((x, y_pic) for x, y, y_pic in zip(x_set, y_set, y_set_pic)\n",
        "         if y == digit),\n",
        "        None)\n",
        "    if first_match is None:\n",
        "      continue\n",
        "    x, y_pic = first_match\n",
        "    result_y.append(color_labels(x, y_pic))\n",
        "    result_x.append(x)\n",
        "  assert len(result_y) == 10\n",
        "\n",
        "  concat_axis = 0 if orientation == \"vertical\" else 1\n",
        "  result_y = np.concatenate(result_y, axis=concat_axis)\n",
        "  result_x = np.stack(result_x)\n",
        "\n",
        "  return result_y, result_x\n",
        "\n",
        "\n",
        "print(\"Generating y pics...\")\n",
        "# Per-pixel one-hot targets for the true labels ([n, r, c, 10]).\n",
        "y_train_pic = to_ten_dim_label(x_train, y_train)\n",
        "y_test_pic = to_ten_dim_label(x_test, y_test)\n",
        "\n",
        "# Adversarial targets: every digit pixel is labelled as class 8.\n",
        "y_train_adv_pic = to_ten_dim_label_hijacked(x_train, 8)\n",
        "\n",
        "# Color legend with one example per digit; x_legend keeps the raw examples.\n",
        "numbers_legend, x_legend = find_different_numbers(x_train, y_train, y_train_pic)\n",
        "numbers_legend_horiz, _ = find_different_numbers(x_train, y_train, y_train_pic, \"horizontal\")\n",
        "\n",
        "imshow(zoom(numbers_legend_horiz))\n",
        "print(\"Storing legend as image for demo.\")\n",
        "# legend.png is copied into Colab's nbextensions folder by the demo cell.\n",
        "matplotlib.image.imsave(\"legend.png\", numbers_legend_horiz)\n",
        "\n",
        "print(\"Storing x_legend for use in the demo.\")\n",
        "import json\n",
        "# JSON list of the 10 example images; injected into the demo as window.SAMPLES.\n",
        "samples_str = json.dumps(x_legend.tolist())\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O-ssztk0fwi5"
      },
      "source": [
        "## Cellular Automata parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zR6I1JONmWBb"
      },
      "outputs": [],
      "source": [
        "#@markdown ### Model configuration\n",
        "#@markdown These options configure the model to be used and trained in this\n",
        "#@markdown notebook. Please refer to the article for more information.\n",
        "CHANNEL_N = 19 # Number of CA state channels (excludes the grayscale input; the last 10 hold class predictions)\n",
        "BATCH_SIZE = 16 # Pool samples drawn per training step\n",
        "POOL_SIZE = BATCH_SIZE * 10 # Number of persistent states kept in the sample pool\n",
        "CELL_FIRE_RATE = 0.5 # Per-step probability that a living cell applies its update\n",
        "\n",
        "# NOTE(review): the adversarial training cell below only defines an L2 loss.\n",
        "MODEL_TYPE = '2 mutating'  #@param ['1 persistent', '2 mutating']\n",
        "LOSS_TYPE = \"l2\"  #@param ['l2', 'ce']\n",
        "ADD_NOISE = \"True\"  #@param ['True', 'False']\n",
        "\n",
        "# Mutating runs swap new digits into pool entries while keeping their evolved\n",
        "# state (see the training loop); the form's string option becomes a bool.\n",
        "MUTATE_POOL = MODEL_TYPE == '2 mutating'\n",
        "ADD_NOISE = ADD_NOISE == 'True'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lCbPFbI_zosW"
      },
      "outputs": [],
      "source": [
        "#@title CA model and utils\n",
        "\n",
        "from tensorflow.keras.layers import Conv2D\n",
        "\n",
        "class CAModel(tf.keras.Model):\n",
        "  # Neural cellular automaton that classifies MNIST digits per pixel.\n",
        "  # States have 1 + channel_n channels: channel 0 is the grayscale input\n",
        "  # (never updated) and the rest is the cell state, whose last 10 channels\n",
        "  # are the class predictions read by classify().\n",
        "\n",
        "  def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE,\n",
        "               add_noise=ADD_NOISE):\n",
        "    # CHANNEL_N does *not* include the greyscale channel.\n",
        "    # but it does include the 10 possible outputs.\n",
        "    super().__init__()\n",
        "    self.channel_n = channel_n\n",
        "    self.fire_rate = fire_rate\n",
        "    self.add_noise = add_noise\n",
        "\n",
        "    # Perception: a 3x3 convolution over the whole state (grayscale channel\n",
        "    # included), so each cell sees its immediate neighbourhood.\n",
        "    self.perceive = tf.keras.Sequential([\n",
        "          Conv2D(80, 3, activation=tf.nn.relu, padding=\"SAME\"),\n",
        "      ])\n",
        "\n",
        "    # Update rule: per-cell 1x1 convolutions producing a residual update for\n",
        "    # the channel_n state channels. The last kernel is zero-initialized,\n",
        "    # presumably so an untrained model starts with (near-)zero updates.\n",
        "    self.dmodel = tf.keras.Sequential([\n",
        "          Conv2D(80, 1, activation=tf.nn.relu),\n",
        "          Conv2D(self.channel_n, 1, activation=None,\n",
        "                       kernel_initializer=tf.zeros_initializer),\n",
        "    ])\n",
        "\n",
        "    self(tf.zeros([1, 3, 3, channel_n + 1]))  # dummy calls to build the model\n",
        "\n",
        "  @tf.function\n",
        "  def call(self, x, fire_rate=None, manual_noise=None):\n",
        "    # One CA step on x (1 + channel_n channels); returns the updated state.\n",
        "    #   fire_rate: probability that each cell applies its update this step\n",
        "    #     (defaults to self.fire_rate).\n",
        "    #   manual_noise: if add_noise is on, used instead of sampled N(0, 0.02)\n",
        "    #     noise; export_model exposes it as an input of the exported graph.\n",
        "    gray, state = tf.split(x, [1, self.channel_n], -1)\n",
        "    ds = self.dmodel(self.perceive(x))\n",
        "    if self.add_noise:\n",
        "      if manual_noise is None:\n",
        "        residual_noise = tf.random.normal(tf.shape(ds), 0., 0.02)\n",
        "      else:\n",
        "        residual_noise = manual_noise\n",
        "      ds += residual_noise\n",
        "\n",
        "    if fire_rate is None:\n",
        "      fire_rate = self.fire_rate\n",
        "    # Stochastic update: each cell fires with probability fire_rate, and only\n",
        "    # living cells (on the digit, gray above 0.1) may change.\n",
        "    update_mask = tf.random.uniform(tf.shape(x[:, :, :, :1])) \u003c= fire_rate\n",
        "    living_mask = gray \u003e 0.1\n",
        "    residual_mask = update_mask \u0026 living_mask\n",
        "    ds *= tf.cast(residual_mask, tf.float32)\n",
        "    state += ds\n",
        "    \n",
        "    # The grayscale channel is passed through unchanged.\n",
        "    return tf.concat([gray, state], -1)\n",
        "\n",
        "  @tf.function\n",
        "  def initialize(self, images):\n",
        "    # Initial states for a batch of 28x28 images: the image as channel 0,\n",
        "    # followed by channel_n zero-valued state channels.\n",
        "    state = tf.zeros([tf.shape(images)[0], 28, 28, self.channel_n])\n",
        "    images = tf.reshape(images, [-1, 28, 28, 1])\n",
        "    return tf.concat([images, state], -1)\n",
        "\n",
        "  @tf.function\n",
        "  def classify(self, x):\n",
        "    # The last 10 channels are the classification predictions, one channel\n",
        "    # per class. Keep in mind there is no \"background\" class,\n",
        "    # and that any loss doesn't propagate to \"dead\" pixels.\n",
        "    return x[:,:,:,-10:]\n",
        "\n",
        "# Print the architecture of both sub-networks from a single instance\n",
        "# (calling CAModel() twice would build and initialize two separate models).\n",
        "summary_ca = CAModel()\n",
        "summary_ca.perceive.summary()\n",
        "summary_ca.dmodel.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xDX5HL7VLd0z"
      },
      "source": [
        "# Training and visualization utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IeWf6HeTe8kM"
      },
      "outputs": [],
      "source": [
        "#@title Train utils (SamplePool, model export, visualizations)\n",
        "from google.protobuf.json_format import MessageToDict\n",
        "from tensorflow.python.framework import convert_to_constants\n",
        "\n",
        "class SamplePool:\n",
        "  def __init__(self, *, _parent=None, _parent_idx=None, **slots):\n",
        "    self._parent = _parent\n",
        "    self._parent_idx = _parent_idx\n",
        "    self._slot_names = slots.keys()\n",
        "    self._size = None \n",
        "    for k, v in slots.items():\n",
        "      if self._size is None:\n",
        "        self._size = len(v)\n",
        "      assert self._size == len(v)\n",
        "      setattr(self, k, np.asarray(v))\n",
        "\n",
        "  def sample(self, n):\n",
        "    idx = np.random.choice(self._size, n, False)\n",
        "    batch = {k: getattr(self, k)[idx] for k in self._slot_names}\n",
        "    batch = SamplePool(**batch, _parent=self, _parent_idx=idx)\n",
        "    return batch\n",
        "\n",
        "  def commit(self):\n",
        "    for k in self._slot_names:\n",
        "      getattr(self._parent, k)[self._parent_idx] = getattr(self, k)\n",
        "\n",
        "def export_model(ca, base_fn):\n",
        "  # Save `ca`'s weights to `base_fn` and write a frozen graph of `ca.call` to\n",
        "  # `base_fn + '.json'` in the TensorFlow.js 'graph-model' JSON layout, with\n",
        "  # variables inlined as constants (hence the empty weightsManifest).\n",
        "  ca.save_weights(base_fn)\n",
        "\n",
        "  # Trace with dynamic batch/height/width. fire_rate is traced from a scalar\n",
        "  # float tensor, so it presumably becomes a graph input like manual_noise.\n",
        "  # NOTE(review): the specs use the global CHANNEL_N rather than ca.channel_n.\n",
        "  cf = ca.call.get_concrete_function(\n",
        "      x=tf.TensorSpec([None, None, None, CHANNEL_N+1]),\n",
        "      fire_rate=tf.constant(0.5),\n",
        "      manual_noise=tf.TensorSpec([None, None, None, CHANNEL_N]))\n",
        "  # Inline variables as constants so the graph is self-contained.\n",
        "  cf = convert_to_constants.convert_variables_to_constants_v2(cf)\n",
        "  graph_def = cf.graph.as_graph_def()\n",
        "  graph_json = MessageToDict(graph_def)\n",
        "  # NOTE(review): version fields are hard-coded, presumably for tfjs compatibility.\n",
        "  graph_json['versions'] = dict(producer='1.14', minConsumer='1.14')\n",
        "  model_json = {\n",
        "      'format': 'graph-model',\n",
        "      'modelTopology': graph_json,\n",
        "      'weightsManifest': [],\n",
        "  }\n",
        "  with open(base_fn+'.json', 'w') as f:\n",
        "    json.dump(model_json, f)\n",
        "\n",
        "def classify_and_color(ca, x, disable_black=False):\n",
        "  # Color each pixel of the state batch x by `ca`'s predicted class, as float\n",
        "  # RGB values (color_labels divides by 255 for non-uint8 dtypes).\n",
        "  gray = x[:,:,:,0]\n",
        "  class_scores = ca.classify(x)\n",
        "  return color_labels(gray, class_scores, disable_black, dtype=tf.float32)\n",
        "\n",
        "\n",
        "def generate_tiled_figures(figures, fade_by=0.1):\n",
        "  # Tile a batch of float RGB figures into one grid (see tile2d) and fade its\n",
        "  # borders to white: a band of int(grid_height * fade_by) pixels on every\n",
        "  # side is blended linearly towards 1.0.\n",
        "  tiled_pool = tile2d(figures)\n",
        "  fade_sz = int(tiled_pool.shape[0] * fade_by)\n",
        "  if fade_sz \u003c= 0:\n",
        "    # Nothing to fade. Without this guard the `[-0:]` slices below would\n",
        "    # select whole axes and the blend would fail to broadcast.\n",
        "    return tiled_pool\n",
        "  fade = np.linspace(1.0, 0.0, fade_sz)\n",
        "  # Blend each border pixel p towards white: p + (1 - p) * fade.\n",
        "  tiled_pool[:, :fade_sz] += (1.0 - tiled_pool[:, :fade_sz]) * fade[None, :, None]\n",
        "  tiled_pool[:, -fade_sz:] += (1.0 - tiled_pool[:, -fade_sz:]) * fade[None, ::-1, None]\n",
        "  tiled_pool[:fade_sz, :] += (1.0 - tiled_pool[:fade_sz, :]) * fade[:, None, None]\n",
        "  tiled_pool[-fade_sz:, :] += (1.0 - tiled_pool[-fade_sz:, :]) * fade[::-1, None, None]\n",
        "  return tiled_pool\n",
        "\n",
        "def generate_pool_figures(ca, pool, step_i, fade_sz=72):\n",
        "  # Save a snapshot of the whole pool, colored by `ca`'s predictions, to\n",
        "  # train_log/\u003cstep\u003e_pool.jpg, with every border faded to white over\n",
        "  # `fade_sz` pixels (previously a hard-coded 72).\n",
        "  tiled_pool = tile2d(classify_and_color(ca, pool.x))\n",
        "  if fade_sz \u003e 0:\n",
        "    fade = np.linspace(1.0, 0.0, fade_sz)\n",
        "    # Blend each border pixel p towards white: p + (1 - p) * fade.\n",
        "    tiled_pool[:, :fade_sz] += (1.0 - tiled_pool[:, :fade_sz]) * fade[None, :, None]\n",
        "    tiled_pool[:, -fade_sz:] += (1.0 - tiled_pool[:, -fade_sz:]) * fade[None, ::-1, None]\n",
        "    tiled_pool[:fade_sz, :] += (1.0 - tiled_pool[:fade_sz, :]) * fade[:, None, None]\n",
        "    tiled_pool[-fade_sz:, :] += (1.0 - tiled_pool[-fade_sz:, :]) * fade[::-1, None, None]\n",
        "  imwrite('train_log/%04d_pool.jpg'%step_i, tiled_pool)\n",
        "\n",
        "def visualize_batch(ca, x0, x, step_i):\n",
        "  # Save and display two rows: the batch before (x0) and after (x) the\n",
        "  # training step, each sample colored by `ca`'s per-pixel predictions.\n",
        "  before, after = [\n",
        "      np.hstack(classify_and_color(ca, states).numpy()) for states in (x0, x)]\n",
        "  vis = np.vstack([before, after])\n",
        "  imwrite('train_log/batches_%04d.jpg'%step_i, vis)\n",
        "  print('batch (before/after):')\n",
        "  imshow(vis)\n",
        "\n",
        "def plot_loss(loss_log):\n",
        "  # Scatter-plot the log10 of the loss history.\n",
        "  log_loss = np.log10(loss_log)\n",
        "  pl.figure(figsize=(10, 4))\n",
        "  pl.title('Loss history (log10)')\n",
        "  pl.plot(log_loss, '.', alpha=0.1)\n",
        "  pl.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qdfCyWWwZFWq"
      },
      "source": [
        "# Training an adversary\n",
        "\n",
        "We train an adversarial CA that controls only a small random fraction of the cells, and whose goal is to make all cells agree on the classification \"8\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u57S5tCxhcNp"
      },
      "source": [
        "# Pretrained models (Original)\n",
        "\n",
        "Run the cells below to download the pretrained original models that the adversary will hijack."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MqxOYKrFS8fU"
      },
      "outputs": [],
      "source": [
        "# Download the pretrained original MNIST CA checkpoints.\n",
        "!wget -O models.zip 'https://github.com/google-research/self-organising-systems/blob/master/assets/mnist_ca/models.zip?raw=true'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ONAhIfBh5X0N"
      },
      "outputs": [],
      "source": [
        "# Extract them into models/, the default prefix used by get_model().\n",
        "!unzip -oq \"models.zip\" -d \"models\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wiGl7S0E6-OA"
      },
      "outputs": [],
      "source": [
        "def get_exp_path(\n",
        "    prefix, use_sample_pool, mutate_pool, loss_type, add_noise):\n",
        "  path = prefix\n",
        "  path += 'use_sample_pool_%r mutate_pool_%r '%(use_sample_pool, mutate_pool)\n",
        "  path += 'loss_type_%s '%(loss_type)\n",
        "  path += 'add_noise_%r'%(add_noise)\n",
        "  path += '/0100000'\n",
        "  return path\n",
        "\n",
        "def get_model(use_sample_pool=True, mutate_pool=True, loss_type=\"l2\", add_noise=True,\n",
        "              prefix=\"models/\", output='model'):\n",
        "  # Load a pretrained original CA from `prefix` (see get_exp_path).\n",
        "  #   output='model': a CAModel with the checkpoint weights loaded.\n",
        "  #   output='json': the exported graph JSON as a string.\n",
        "  path = get_exp_path(\n",
        "      prefix, use_sample_pool, mutate_pool, loss_type, add_noise)\n",
        "  assert output in ['model', 'json']\n",
        "  if output == 'model':\n",
        "    ca = CAModel(add_noise=add_noise)\n",
        "    ca.load_weights(path)\n",
        "    return ca\n",
        "  # output == 'json'; the context manager closes the file handle.\n",
        "  with open(path+'.json', 'r') as f:\n",
        "    return f.read()\n",
        "\n",
        "def get_local_model(path, output='model'):\n",
        "  # Load a CA checkpoint from an explicit `path`.\n",
        "  #   output='model': a CAModel (default settings) with the weights loaded.\n",
        "  #   output='json': the contents of path + '.json' as a string.\n",
        "  assert output in ['model', 'json']\n",
        "  if output == 'model':\n",
        "    ca = CAModel()\n",
        "    ca.load_weights(path)\n",
        "    return ca\n",
        "  # output == 'json'; the context manager closes the file handle.\n",
        "  with open(path+'.json', 'r') as f:\n",
        "    return f.read()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "btvTxR82O1Wu"
      },
      "source": [
        "## Actual training procedure.\n",
        "\n",
        "Run this section only if you want to train new adversaries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ak5rBmbxmHV7"
      },
      "outputs": [],
      "source": [
        "# Initialize things for a new training run\n",
        "\n",
        "orig_ca = get_model()\n",
        "adv_ca = CAModel()\n",
        "\n",
        "@tf.function\n",
        "def individual_l2_loss(x, y):\n",
        "  # Per-sample half sum of squared errors between targets y and the class\n",
        "  # channels of the states x.\n",
        "  # Note how classify is actually not ca specific: it could be a static method.\n",
        "  t = y - orig_ca.classify(x)\n",
        "  return tf.reduce_sum(t**2, [1, 2, 3]) / 2\n",
        "\n",
        "@tf.function\n",
        "def batch_l2_loss(x, y):\n",
        "  # Mean of individual_l2_loss over the batch.\n",
        "  return tf.reduce_mean(individual_l2_loss(x, y))\n",
        "\n",
        "\n",
        "# Only the L2 loss is implemented for the adversary (train_step uses\n",
        "# batch_l2_loss directly), so fail clearly instead of with a NameError.\n",
        "if LOSS_TYPE != \"l2\":\n",
        "  raise NotImplementedError(\n",
        "      \"Adversarial training only supports LOSS_TYPE='l2', got %r\" % LOSS_TYPE)\n",
        "loss_fn = batch_l2_loss\n",
        "\n",
        "loss_log = []\n",
        "# One adversarial loss value is appended per training step.\n",
        "adv_loss_log = []\n",
        "\n",
        "# Adam with a piecewise-constant learning rate: 1e-3, dropping 10x at\n",
        "# step 30000 and again at step 70000.\n",
        "lr = 1e-3\n",
        "lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n",
        "      [30000, 70000], [lr, lr*0.1, lr*0.01])\n",
        "trainer = tf.keras.optimizers.Adam(lr_sched)\n",
        "\n",
        "\n",
        "#@markdown Recommended: 0.1 percentage_virus\n",
        "percentage_virus = 0.1 #@param [0.1] {allow-input: true}\n",
        "\n",
        "# Random training examples for the initial pool. randint's upper bound is\n",
        "# exclusive, so passing the dataset size makes every example selectable.\n",
        "starting_indexes = np.random.randint(0, x_train.shape[0], size=POOL_SIZE)\n",
        "# Per-cell mask m: 1 where the adversary controls the cell (a random\n",
        "# `percentage_virus` fraction of cells), 0 where the original CA does.\n",
        "initial_random_mask = tf.cast(\n",
        "    tf.random.uniform([POOL_SIZE, 28, 28, 1]) \u003c percentage_virus, tf.float32).numpy()\n",
        "# The target will ALWAYS be 8. But we give it already shaped for optimization purposes.\n",
        "pool = SamplePool(x=orig_ca.initialize(x_train[starting_indexes]).numpy(),\n",
        "                  yadv=y_train_adv_pic[starting_indexes],\n",
        "                  y=y_train_pic[starting_indexes],\n",
        "                  m=initial_random_mask)\n",
        "\n",
        "# Start from empty log directories (checkpoints and figures go to train_log/).\n",
        "!mkdir -p train_log \u0026\u0026 rm -f train_log/*\n",
        "!mkdir -p train_log_orig \u0026\u0026 rm -f train_log_orig/*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QzP_vDchq0d9"
      },
      "outputs": [],
      "source": [
        "#@title Training loop {vertical-output: true}\n",
        "\n",
        "@tf.function\n",
        "def train_step(x, y, yadv, m):\n",
        "  # One optimization step for the adversary. For iter_n CA steps both CAs\n",
        "  # update the full state and the results are mixed per cell by mask m\n",
        "  # (1 = adversarial cell, 0 = original cell). The L2 loss against yadv is\n",
        "  # backpropagated through all steps into adv_ca's weights only.\n",
        "  # Returns (final state, adversarial loss). `y` is not used here.\n",
        "  iter_n = 20\n",
        "  with tf.GradientTape() as tape:\n",
        "    for _step in tf.range(iter_n):\n",
        "      next_orig = orig_ca(x)\n",
        "      next_adv = adv_ca(x)\n",
        "      x = next_orig * (1. - m) + next_adv * m\n",
        "    adv_loss = batch_l2_loss(x, yadv)\n",
        "\n",
        "  grads = tape.gradient(adv_loss, adv_ca.weights)\n",
        "  # Normalize each gradient tensor to unit L2 norm (1e-8 avoids division by 0).\n",
        "  grads = [grad/(tf.norm(grad)+1e-8) for grad in grads]\n",
        "  trainer.apply_gradients(zip(grads, adv_ca.weights))\n",
        "\n",
        "  return x, adv_loss\n",
        "\n",
        "for i in range(1, 100000+1):\n",
        "  batch = pool.sample(BATCH_SIZE)\n",
        "  x0 = np.copy(batch.x)\n",
        "  y0 = batch.y\n",
        "  yadv0 = batch.yadv\n",
        "  m0 = batch.m\n",
        "  # we want half of them new. We remove 1/4 from the top and 1/4 from the\n",
        "  # bottom.\n",
        "  q_bs = BATCH_SIZE // 4\n",
        "\n",
        "  # First quarter: fresh examples with fresh states and fresh masks.\n",
        "  # randint's upper bound is exclusive, so use the dataset size.\n",
        "  new_idx = np.random.randint(0, x_train.shape[0], size=q_bs)\n",
        "  x0[:q_bs] = orig_ca.initialize(x_train[new_idx])\n",
        "  y0[:q_bs] = y_train_pic[new_idx]\n",
        "  yadv0[:q_bs] = y_train_adv_pic[new_idx]\n",
        "  m0[:q_bs] = tf.cast(\n",
        "      tf.random.uniform([q_bs, 28, 28, 1]) \u003c percentage_virus,\n",
        "      tf.float32).numpy()\n",
        "\n",
        "  # Last quarter: new digits. When mutating, the evolved state (restricted to\n",
        "  # the new digit's pixels) and the mask are kept; otherwise reset everything.\n",
        "  new_idx = np.random.randint(0, x_train.shape[0], size=q_bs)\n",
        "  new_x, new_y, new_yadv = x_train[new_idx], y_train_pic[new_idx], y_train_adv_pic[new_idx]\n",
        "  if MUTATE_POOL:\n",
        "    new_x = tf.reshape(new_x, [q_bs, 28, 28, 1])\n",
        "    mutate_mask = tf.cast(new_x \u003e 0.1, tf.float32)\n",
        "    mutated_x = tf.concat([new_x, x0[-q_bs:,:,:,1:] * mutate_mask], -1)\n",
        "\n",
        "    x0[-q_bs:] = mutated_x\n",
        "    y0[-q_bs:] = new_y\n",
        "    yadv0[-q_bs:] = new_yadv\n",
        "    # Do not modify the mask!\n",
        "\n",
        "  else:\n",
        "    x0[-q_bs:] = orig_ca.initialize(new_x)\n",
        "    y0[-q_bs:] = new_y\n",
        "    yadv0[-q_bs:] = new_yadv\n",
        "    # Modify the mask too.\n",
        "    m0[-q_bs:] = tf.cast(\n",
        "        tf.random.uniform([q_bs, 28, 28, 1]) \u003c percentage_virus,\n",
        "        tf.float32).numpy()\n",
        "\n",
        "  x, adv_loss = train_step(x0, y0, yadv0, m0)\n",
        "\n",
        "  batch.x[:] = x\n",
        "  # These get reordered.\n",
        "  batch.y[:] = y0\n",
        "  batch.yadv[:] = yadv0\n",
        "  batch.m[:] = m0\n",
        "  batch.commit()\n",
        "\n",
        "  step_i = len(adv_loss_log)\n",
        "  adv_loss_log.append(adv_loss.numpy())\n",
        "\n",
        "  if step_i%100 == 0:\n",
        "    generate_pool_figures(orig_ca, pool, step_i)\n",
        "  if step_i%200 == 0:\n",
        "    clear_output()\n",
        "    visualize_batch(orig_ca, x0, x, step_i)\n",
        "    pl.figure(figsize=(10, 4))\n",
        "    pl.title('Loss history (log10)')\n",
        "    pl.plot(np.log10(adv_loss_log), '.', alpha=0.1, label=\"adv_loss\")\n",
        "    pl.legend()\n",
        "    pl.show()\n",
        "  if step_i%10000 == 0:\n",
        "    export_model(adv_ca, 'train_log/%07d'%step_i)\n",
        "\n",
        "  # Report the current step (loss_log is never filled, so its length was always 0).\n",
        "  print('\\r step: %d, log10(adv_loss): %.3f'%(\n",
        "      step_i, np.log10(adv_loss)), end='')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2vhKajnuxINp"
      },
      "outputs": [],
      "source": [
        "# If the training loop above was interrupted, export the adversary under\n",
        "# the most recently logged step number.\n",
        "print(step_i)\n",
        "export_model(adv_ca, 'train_log/%07d' % (step_i,))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "TAmsFxk9GMWk"
      },
      "outputs": [],
      "source": [
        "#@title Save data for reproducing the tfjs powered Distill demo\n",
        "# Requires the adversarial checkpoint in saved_adversarial/, which the\n",
        "# \"Pretrained model (Adversaries)\" section below downloads.\n",
        "\n",
        "def get_weights_b64(ca):\n",
        "  # JSON-encoded list with one base64 string per weight tensor, holding the\n",
        "  # tensor's raw bytes, in ca.weights order.\n",
        "  encoded = []\n",
        "  for var in ca.weights:\n",
        "    raw_bytes = var.numpy().tobytes()\n",
        "    encoded.append(base64.b64encode(raw_bytes).decode('ascii'))\n",
        "  return json.dumps(encoded)\n",
        "\n",
        "weights_b64 = get_weights_b64(get_model())\n",
        "adv_weights_b64 = get_weights_b64(get_local_model(\"saved_adversarial/saved_model\"))\n",
        "\n",
        "data_js = '''\n",
        "  window.WEIGHTS_B64 = %s;\n",
        "  window.ADV_WEIGHTS_B64 = %s;\n",
        "''' % (weights_b64, adv_weights_b64)\n",
        "\n",
        "with open(\"data.js\", \"w\") as f:\n",
        "  f.write(data_js)\n",
        "\n",
        "!ls data.js\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tXf6_os6KnXv"
      },
      "source": [
        "# Pretrained model (Adversaries)\n",
        "Run the cell below to download and load the pretrained adversary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8P4nqGXj2dKr"
      },
      "outputs": [],
      "source": [
        "# Download and unpack the pretrained adversarial CA, then load it as adv_ca\n",
        "# (replacing any adversary created in the training cells above).\n",
        "!wget -O adversarial_model.zip 'https://github.com/google-research/self-organising-systems/blob/master/adversarial_reprogramming_ca/assets/mnist_ca_adversarial_model.zip?raw=true'\n",
        "\n",
        "!unzip -oq \"adversarial_model.zip\" -d \"saved_adversarial\"\n",
        "\n",
        "adv_ca = get_local_model(\"saved_adversarial/saved_model\", output='model')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tqhQOXuh2kWC"
      },
      "source": [
        "# Demo in this colab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "46N0wlW44mQD"
      },
      "outputs": [],
      "source": [
        "#@title TensorFlow.js Demo small {run:\"auto\", vertical-output: true}\n",
        "#@markdown Select \"CHECKPOINT\" model to load the checkpoint created by running\n",
        "#@markdown cells from the \"Training\" section of this notebook.\n",
        "#@markdown Technical note: CE models should be rendered differently to avoid\n",
        "#@markdown black pixels showing for low magnitude.\n",
        "#@markdown draw with left click, hold shift for erasing\n",
        "import IPython.display\n",
        "\n",
        "def read_latest_checkpoint(pattern):\n",
        "  # Read the newest exported graph JSON matching `pattern`. Checkpoints are\n",
        "  # named by zero-padded step, so the lexicographically last is the newest.\n",
        "  matches = sorted(glob.glob(pattern))\n",
        "  if not matches:\n",
        "    raise FileNotFoundError('No checkpoint matches %r.' % pattern)\n",
        "  with open(matches[-1]) as f:\n",
        "    return f.read()\n",
        "\n",
        "# Copy the legend in the proper place.\n",
        "!cp /content/legend.png /usr/local/share/jupyter/nbextensions/google.colab/\n",
        "\n",
        "# These are the original models' parameters. Modify them manually\n",
        "# if you want to see different behaviors.\n",
        "model_source = \"LOAD\"  # ['CHECKPOINT', 'LOAD']\n",
        "model_type = '3 mutating'  # ['1 naive', '2 persistent', '3 mutating']\n",
        "loss_type = \"l2\"  #['l2', 'ce']\n",
        "add_noise = \"True\"  #['True', 'False']\n",
        "\n",
        "if model_source != 'CHECKPOINT':\n",
        "  use_sample_pool, mutate_pool = {\n",
        "      '1 naive': (False, False),\n",
        "      '2 persistent': (True, False),\n",
        "      '3 mutating': (True, True)\n",
        "      }[model_type]\n",
        "  add_noise = add_noise == 'True'\n",
        "\n",
        "  model_str = get_model(\n",
        "      use_sample_pool=use_sample_pool, mutate_pool=mutate_pool,\n",
        "      loss_type=loss_type, add_noise=add_noise,\n",
        "      output='json')\n",
        "else:\n",
        "  # NOTE(review): nothing in this notebook writes to train_log_orig/ (the\n",
        "  # training cell only creates and clears it), so this branch needs an\n",
        "  # original-model checkpoint produced elsewhere.\n",
        "  model_str = read_latest_checkpoint('train_log_orig/*.json')\n",
        "\n",
        "adversarial_model_source = \"LOAD\"  #@param ['CHECKPOINT', 'LOAD']\n",
        "\n",
        "if adversarial_model_source != \"CHECKPOINT\":\n",
        "  adv_model_str = get_local_model(\"saved_adversarial/saved_model\",output=\"json\")\n",
        "else:\n",
        "  adv_model_str = read_latest_checkpoint('train_log/*.json')\n",
        "\n",
        "data_js = '''\n",
        "  window.GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));\n",
        "  window.ADV_GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));\n",
        "  window.SAMPLES = %s\n",
        "'''%(model_str, adv_model_str, samples_str)\n",
        "\n",
        "display(IPython.display.Javascript(data_js))\n",
        "\n",
        "\n",
        "IPython.display.HTML('''\n",
        "\u003cscript src=\"https://unpkg.com/@tensorflow/tfjs@latest/dist/tf.min.js\"\u003e\u003c/script\u003e\n",
        "\u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/cash/4.1.2/cash.min.js\"\u003e\u003c/script\u003e\n",
        "\u003cdiv\u003e\u003cimg src='/nbextensions/google.colab/legend.png' /\u003e\u003c/div\u003e\n",
        "\u003ccanvas id='canvas' style=\"border: 1px solid black; image-rendering: pixelated;\"\u003e\u003c/canvas\u003e\n",
        "\u003cdiv class=\"slidecontainer\"\u003e\n",
        "    brushSize:\n",
        "    \u003cinput type=\"range\" min=\"1\" max=\"10\" value=\"5\" class=\"slider\" id=\"brushSlider\"\u003e\n",
        "    \u003cspan id='radius'\u003e2.5\u003c/span\u003e\n",
        "\u003c/div\u003e\n",
        "\u003cdiv\u003e\u003cbutton type=\"button\" id=\"removeadv\"\u003eRemove adversaries\u003c/button\u003e\u003c/div\u003e\n",
        "\n",
        "\u003cdiv class=\"boxcontainer\"\u003e\n",
        "\u003cinput type=\"checkbox\" id=\"drawadversary\" name=\"drawadversary\"\u003e\n",
        "\u003clabel for=\"drawadversary\"\u003eDraw adversary\u003c/label\u003e\u003cbr\u003e\n",
        "\u003c/div\u003e\n",
        "\u003cdiv class=\"boxcontainer\"\u003e\n",
        "\u003cinput type=\"checkbox\" id=\"showadvmask\" name=\"showadvmask\"\u003e\n",
        "\u003clabel for=\"showadvmask\"\u003eShow adversary mask\u003c/label\u003e\u003cbr\u003e\n",
        "\u003c/div\u003e\n",
        "\u003cscript\u003e\n",
        "  \"use strict\";\n",
        "\n",
        "// Interactive demo: an original and an adversarial CA update one shared\n",
        "// state grid. adv_mask marks the cells driven by the adversarial model and\n",
        "// orig_mask (always 1 - adv_mask) the cells driven by the original one.\n",
        "const main = async () =\u003e {\n",
        "\n",
        "  // Builds {nodeName: [tensor]} from the Const nodes of a graph-model JSON;\n",
        "  // the result is merged into GraphModel.weights after loading.\n",
        "  const parseConsts = model_graph=\u003e{\n",
        "    const dtypes = {'DT_INT32':['int32', 'intVal', Int32Array],\n",
        "                    'DT_FLOAT':['float32', 'floatVal', Float32Array]};\n",
        "\n",
        "    const consts = {};\n",
        "    model_graph.modelTopology.node.filter(n=\u003en.op=='Const').forEach((node=\u003e{\n",
        "      const v = node.attr.value.tensor;\n",
        "      const [dtype, field, arrayType] = dtypes[v.dtype];\n",
        "      if (!v.tensorShape.dim) {\n",
        "        consts[node.name] = [tf.scalar(v[field][0], dtype)];\n",
        "      } else {\n",
        "        const shape = v.tensorShape.dim.map(d=\u003eparseInt(d.size));\n",
        "        let arr;\n",
        "        if (v.tensorContent) {\n",
        "          // tensorContent is base64 text; decode it into raw bytes.\n",
        "          const data = atob(v.tensorContent);\n",
        "          const buf = new Uint8Array(data.length);\n",
        "          for (var i=0; i\u003cdata.length; ++i) {\n",
        "            buf[i] = data.charCodeAt(i);\n",
        "          }\n",
        "          arr = new arrayType(buf.buffer);\n",
        "        } else {\n",
        "          // No raw content: fill the whole shape with the single stored value.\n",
        "          const size = shape.reduce((a, b)=\u003ea*b);\n",
        "          arr = new arrayType(size);\n",
        "          arr.fill(v[field][0]);\n",
        "        }\n",
        "        consts[node.name] = [tf.tensor(arr, shape, dtype)];\n",
        "      }\n",
        "    }));\n",
        "    return consts;\n",
        "  }\n",
        "\n",
        "  let paused = false;\n",
        "  let firingChance = 0.5;\n",
        "  // Brush radius; the slider maps its value v to v / 2 (default 5 -\u003e 2.5).\n",
        "  let drawRadius = 2.5;\n",
        "\n",
        "  let drawadversaryCkbx = document.getElementById(\"drawadversary\");\n",
        "  let showadvmaskCkbx = document.getElementById(\"showadvmask\");\n",
        "\n",
        "  $('#brushSlider').on('input', e=\u003e{\n",
        "      drawRadius = parseFloat(e.target.value)/2.0;\n",
        "      $('#radius').text(drawRadius);\n",
        "  });\n",
        "\n",
        "  // RGB for each of the 10 class channels, then black (10) and white (11).\n",
        "  const colorLookup = tf.tensor([\n",
        "      [128, 0, 0],\n",
        "      [230, 25, 75],\n",
        "      [70, 240, 240],\n",
        "      [210, 245, 60],\n",
        "      [250, 190, 190],\n",
        "      [170, 110, 40],\n",
        "      [170, 255, 195],\n",
        "      [165, 163, 159],\n",
        "      [0, 128, 128],\n",
        "      [128, 128, 0],\n",
        "      [0, 0, 0], // This is the default for digits.\n",
        "      [255, 255, 255] // This is the background.\n",
        "      ])\n",
        "\n",
        "  let backgroundWhite = true;\n",
        "\n",
        "\n",
        "  const run = async () =\u003e {\n",
        "      const r = await fetch(GRAPH_URL);\n",
        "      const consts = parseConsts(await r.json());\n",
        "\n",
        "      const model = await tf.loadGraphModel(GRAPH_URL);\n",
        "\n",
        "      const samples = tf.tensor(SAMPLES);\n",
        "      console.log(samples.shape);\n",
        "\n",
        "      console.log(\"Loaded model\")\n",
        "      Object.assign(model.weights, consts);\n",
        "\n",
        "      // Adversarial model now.\n",
        "      const rad = await fetch(ADV_GRAPH_URL);\n",
        "      const adv_consts = parseConsts(await rad.json());\n",
        "      const adv_model = await tf.loadGraphModel(ADV_GRAPH_URL);\n",
        "\n",
        "      console.log(\"Loaded adv model\")\n",
        "      Object.assign(adv_model.weights, adv_consts);\n",
        "\n",
        "\n",
        "      const D = 28 * 2;\n",
        "      const state = tf.variable(tf.zeros([1, D, D, 20]));\n",
        "      // this is where we keep track of where is which CA.\n",
        "      const adv_mask = tf.variable(tf.zeros([1, D, D, 1]));\n",
        "      // store this to avoid recomputations.\n",
        "      const orig_mask = tf.variable(tf.ones([1, D, D, 1]));\n",
        "      const [_, h, w, ch] = state.shape;\n",
        "\n",
        "      $('#removeadv').on('click', e=\u003e{\n",
        "          tf.tidy(()=\u003e{\n",
        "            adv_mask.assign(tf.zeros([1, D, D, 1]));\n",
        "            orig_mask.assign(tf.ones([1, D, D, 1]));\n",
        "          });\n",
        "      });\n",
        "\n",
        "      const scale = 8;\n",
        "\n",
        "      const canvas = document.getElementById('canvas');\n",
        "      const ctx = canvas.getContext('2d');\n",
        "      canvas.width = w * scale;\n",
        "      canvas.height = h * scale;\n",
        "\n",
        "      // Strokes are rasterised here at state resolution, then read back.\n",
        "      const drawing_canvas = new OffscreenCanvas(w, h);\n",
        "      const draw_ctx = drawing_canvas.getContext('2d');\n",
        "\n",
        "      // Two low-valued extra logits: when no class channel exceeds 0.01 the\n",
        "      // argMax in render falls back to the black or white colorLookup rows.\n",
        "      let arr = new Float32Array(h * w * 2);\n",
        "      arr.fill(0.01);\n",
        "      const blackAndWhiteFull = tf.tensor(arr, [1,h,w,2], 'float32')\n",
        "\n",
        "      // Applies the red channel of the offscreen canvas to the state: draw\n",
        "      // adds to channel 0 (and marks adversary cells if requested), shift\n",
        "      // erases both the state and the adversary marks.\n",
        "      const drawCanvas = (imgd, e) =\u003e {\n",
        "          var matrix = [];\n",
        "          for(let i=0; i\u003cimgd.width; i++) {\n",
        "              matrix[i] = [];\n",
        "              for(let j=0; j\u003cimgd.height; j++) {\n",
        "                  // ImageData is row-major RGBA, so the row stride is width * 4.\n",
        "                  const px = (imgd.width*j + i)*4;\n",
        "                  let intensity = imgd.data[px];\n",
        "                  // For drawing, we want to add shades of grey. For erasing, we don't.\n",
        "                  if (!e.shiftKey) {\n",
        "                    intensity *= (imgd.data[px + 3] / 255);\n",
        "                  }\n",
        "                  matrix[i][j] = intensity;\n",
        "              }\n",
        "          }\n",
        "\n",
        "          tf.tidy(() =\u003e {\n",
        "              const stroke = tf.tensor(matrix).transpose().cast('float32').div(255.).expandDims(0).expandDims(3);\n",
        "              const stroke_pad = tf.concat([stroke, tf.zeros([1, h, w, ch-1])], 3);\n",
        "              const mask = tf.tensor(1.).sub(stroke);\n",
        "              if (e.shiftKey) {\n",
        "                  state.assign(state.mul(mask));\n",
        "                  // delete adversaries too in that case...\n",
        "                  adv_mask.assign(adv_mask.mul(mask));\n",
        "                  orig_mask.assign(adv_mask.sub(1).mul(-1));\n",
        "              } else {\n",
        "                  state.assign(state.mul(mask).add(stroke_pad));\n",
        "                  if (drawadversaryCkbx.checked == true) {\n",
        "                    adv_mask.assign(adv_mask.mul(mask).add(stroke));\n",
        "                    orig_mask.assign(adv_mask.sub(1).mul(-1));\n",
        "                  }\n",
        "              }\n",
        "          });\n",
        "\n",
        "          // Then clear the canvas.\n",
        "          draw_ctx.clearRect(0, 0, draw_ctx.canvas.width, draw_ctx.canvas.height);\n",
        "      }\n",
        "\n",
        "      // Strokes a segment from (x0, y0) to (x1, y1) and applies it.\n",
        "      const line = (x0, y0, x1, y1, r, e) =\u003e {\n",
        "          draw_ctx.beginPath();\n",
        "          draw_ctx.moveTo(x0, y0);\n",
        "          draw_ctx.lineTo(x1, y1);\n",
        "          draw_ctx.strokeStyle = \"#ff0000\";\n",
        "          // Erasing has a much larger radius.\n",
        "          draw_ctx.lineWidth = (e.shiftKey ? 5. * r : r);\n",
        "          draw_ctx.stroke();\n",
        "\n",
        "          const imgd = draw_ctx.getImageData(0, 0, draw_ctx.canvas.width, draw_ctx.canvas.height);\n",
        "          drawCanvas(imgd, e);\n",
        "      }\n",
        "\n",
        "\n",
        "      // Paints a dot at (x, y) and applies it. With Draw adversary checked\n",
        "      // only the single pixel (x, y) is written.\n",
        "      const circle = (x, y, r, e) =\u003e {\n",
        "          if (drawadversaryCkbx.checked) {\n",
        "            // perform surgical insertions!\n",
        "            // Set the colour explicitly: only the red channel is read back and\n",
        "            // the default fillStyle (black) would insert nothing.\n",
        "            draw_ctx.fillStyle = \"#ff0000\";\n",
        "            draw_ctx.fillRect(x, y, 1, 1);\n",
        "            const imgd = draw_ctx.getImageData(\n",
        "              0, 0, draw_ctx.canvas.width, draw_ctx.canvas.height);\n",
        "            drawCanvas(imgd, e);\n",
        "            return;\n",
        "          }\n",
        "          draw_ctx.beginPath();\n",
        "\n",
        "          const drawRadius = (e.shiftKey ? 5. * r : r) / 3.;\n",
        "\n",
        "          draw_ctx.arc(x, y, drawRadius, 0, 2 * Math.PI, false);\n",
        "          draw_ctx.fillStyle = \"#ff0000\";\n",
        "          draw_ctx.fill();\n",
        "          draw_ctx.lineWidth = 1;\n",
        "          draw_ctx.strokeStyle = \"#ff0000\";\n",
        "          draw_ctx.stroke();\n",
        "\n",
        "          const imgd = draw_ctx.getImageData(0, 0, draw_ctx.canvas.width, draw_ctx.canvas.height);\n",
        "          drawCanvas(imgd, e);\n",
        "      }\n",
        "\n",
        "      // Converts a mouse event to state-grid coordinates.\n",
        "      const getClickPos = e=\u003e{\n",
        "          const x = Math.floor((e.pageX-e.target.offsetLeft) / scale);\n",
        "          const y = Math.floor((e.pageY-e.target.offsetTop) / scale);\n",
        "          return [x, y];\n",
        "      }\n",
        "\n",
        "      let lastX = 0;\n",
        "      let lastY = 0;\n",
        "\n",
        "      canvas.onmousedown = e =\u003e {\n",
        "          const [x, y] = getClickPos(e);\n",
        "          lastX = x;\n",
        "          lastY = y;\n",
        "          circle(x,y,drawRadius, e);\n",
        "      }\n",
        "      canvas.onmousemove = e =\u003e {\n",
        "          const [x, y] = getClickPos(e);\n",
        "          if (e.buttons == 1) {\n",
        "              line(lastX,lastY, x,y,drawRadius, e);\n",
        "          }\n",
        "          lastX = x;\n",
        "          lastY = y;\n",
        "      }\n",
        "\n",
        "      const initT = new Date().getTime() / 1000;\n",
        "      // One animation frame: step both CAs (unless paused), combine their\n",
        "      // outputs through the masks, then paint the state onto the canvas.\n",
        "      const render = async () =\u003e {\n",
        "        if (!paused) {\n",
        "          tf.tidy(() =\u003e {\n",
        "            const orig_state = model.execute(\n",
        "                  { x: state,\n",
        "                    fire_rate: tf.tensor(firingChance),\n",
        "                    manual_noise: tf.randomNormal([1, h, w, ch-1], 0., 0.02)},\n",
        "                  ['Identity']);\n",
        "            const adv_state = adv_model.execute(\n",
        "                  { x: state,\n",
        "                    fire_rate: tf.tensor(firingChance),\n",
        "                    manual_noise: tf.randomNormal([1, h, w, ch-1], 0., 0.02)},\n",
        "                  ['Identity']);\n",
        "            state.assign(orig_state.mul(orig_mask).add(adv_state.mul(adv_mask)));\n",
        "          });\n",
        "        }\n",
        "        const imageData = tf.tidy(() =\u003e {\n",
        "            let rgbaBytes;\n",
        "            let rgba;\n",
        "            if (showadvmaskCkbx.checked == false) {\n",
        "                const isGray = state.slice([0,0,0,0],[1, h, w, 1]).greater(0.1).cast('float32');\n",
        "                const isNotGray = tf.tensor(1.).sub(isGray);\n",
        "\n",
        "                const bnwOrder = backgroundWhite ?  [isGray, isNotGray] : [isNotGray, isGray];\n",
        "                let blackAndWhite = blackAndWhiteFull.mul(tf.concat(bnwOrder, 3));\n",
        "\n",
        "                // Channel 0, scaled to 0-255, becomes the alpha channel.\n",
        "                const grey = state.gather([0], 3).mul(255);\n",
        "                // argMax over the 10 class channels plus the 2 black/white\n",
        "                // channels picks a colorLookup row for every cell.\n",
        "                const rgb = tf.gather(colorLookup,\n",
        "                                      tf.argMax(\n",
        "                                      tf.concat([\n",
        "                  state.slice([0,0,0,ch-10],[1,h,w,10]),\n",
        "                  blackAndWhite], 3), 3));\n",
        "\n",
        "                rgba = tf.concat([rgb, grey], 3)\n",
        "\n",
        "                // Add blinking adversary: blend a fixed colour with the\n",
        "                // underlying image, weighted by |sin(t)|.\n",
        "                const adv_rgba = tf.tensor([[[[15., 0., 0., 255.]]]]);\n",
        "                const seconds = new Date().getTime() / 1000 - initT;\n",
        "                const t = tf.tensor(seconds).sin().abs();\n",
        "                const onemt = tf.tensor(1.).sub(t);\n",
        "                const adv_mask_period = adv_mask.mul(t);\n",
        "                const img_behind_mask = rgba.mul(adv_mask);\n",
        "                const adv_image = adv_rgba.mul(adv_mask_period).add(\n",
        "                  img_behind_mask.mul(onemt));\n",
        "\n",
        "                rgba = rgba.mul(orig_mask).add(adv_image);\n",
        "\n",
        "            } else {\n",
        "                rgba = adv_mask.gather([0, 0, 0], 3)\n",
        "                  .pad([[0, 0], [0, 0], [0, 0], [0, 1]], 1).mul(255);\n",
        "            }\n",
        "            rgbaBytes = new Uint8ClampedArray(rgba.dataSync());\n",
        "\n",
        "            return new ImageData(rgbaBytes, w, h);\n",
        "        });\n",
        "        const image = await createImageBitmap(imageData);\n",
        "        ctx.fillStyle = backgroundWhite ? \"#ffffff\" : \"#000000\";\n",
        "        ctx.fillRect(0, 0, canvas.width, canvas.height);\n",
        "        ctx.imageSmoothingEnabled = false;\n",
        "        ctx.drawImage(image, 0, 0, canvas.width, canvas.height);\n",
        "        // Release the per-frame bitmap now instead of waiting for GC.\n",
        "        image.close();\n",
        "\n",
        "        requestAnimationFrame(render);\n",
        "      }\n",
        "      render();\n",
        "  }\n",
        "\n",
        "  run();\n",
        "}\n",
        "main();\n",
        "\u003c/script\u003e\n",
        "''')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAscSKkRaFwp"
      },
      "source": [
        "# Figures"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Itv0WEuRVWoh"
      },
      "source": [
        "## Visualize runs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l1Ri5SmXLqPk"
      },
      "outputs": [],
      "source": [
        "slider_b64 = \"\"\"iVBORw0KGgoAAAANSUhEUgAAAjAAAAAWCAYAAADem6ZtAAAACXBIWXMAAAmOAAAJjgHxlqiVAAAAGXRFWHRTb2Z0d2FyZQB3d3cuaW5rc2NhcGUub3Jnm+48GgAAHZdJREFUeJztnXd4G9eZr98BBr0QjSDABoK9iZQoyZZkyZKs2HKRXPPEcVEcp5fdOMUpe+/ebLK7d5MtSZ69m+x6N7mp3jy5sROvbclVxbKsYlmFogqLxCoSLCDROwbA/QMUJK3kOE7s0Jbx/kM+Mzjn/Gbm4JzvfN93BkIul8v54xLbh4KkpByCIFCkyJVG1UtPUReaQBDFhZZSpMjbh5BF3qpDZStdaCVFirzl5MghiBr0rdch15kQ/XGJJ/sCNNnUKOVF46XIlYkMCdFsRlAoFlpKkSJvG7lMEqVRS0lV2UJLKVLkbSGblggeeRJj123Itg+GisZLkSJFihQpUuQdj0whUlLlIHxyO7KElCsaL0WKFClSpEiRdwUyUQQpgSy30EqKFClSpEiRIkXeDLkssoXW8G5mYvg0vYf3Ewn6F1rKG+L3TpOIRRdaRpE3weG5OQ7Nzi60jDckk8sxHo2SyRWXQ0XemEw2y54jQ7x6fHShpbwhaSnD6KSfYte+PL5gjD1HhugbnlmQ9osGzB9B3+H97Nn6GL6ZyYWW8juZm5rgsX/9e7b9/N8WWkqRN8HTY2M8PTa20DLekN2TkzzS18fuyXf29+DtIpvL4U+lCKbTCy3lXYEkZXl02yF+s/3YQkt5Qx5/8Rh/96MXeenQ6YWW8o5kYibIo9sOsefI4IK0X9xT+h7AYLZS174ER5V7oaUUuQKpNxrxxGI0lJQstJQFIZbJ8J3jx7Gq1XyhrW2h5RR5C+loLCcQjtPosi+0lCKX4V1vwPi900yNDVFWVUM8EmZmYhSj2UZt2+LCO23SqRRDp7qJBPwYzBbcLR0olKrL1pdOJhk61U3I70NvLKG2fQkqtaZwfmJogJmJUSxl5ReVG+k7TjwaoaapHY3eQDIRZ+hkN0q1hrq2xfmyw6eZGhtCrdVT3dCCwWQBYPBkN6lEnMq6JgZPHKWubTEGs5Xps8N4RgaRyWQ4XLWUVdYU2gvOeRkdOEk6mcTmrKC6sQ1BEAgHfIwP9mMpcxILh0jEItS2LaG8ph6NTl8on5HSDJ08Rsg/h9Fipba1E7mY32KcyUiM9B4nMDuDWqvD3dqBVm/84x9WEeaSSfqCQZKShEOjocVs5lwKvZTN0uPzEU6naTGZCmXS2SzdPh8KmYzFlnyfmYzHGY9GcWo0VOp0SLkcJ/x+fIkEFpWKdosFURCISxInAgFMSiUKmYzRSIS1DgeZXI5TgQDeeBytKNJqNmO8YIv5YCjEaDSKQhCoLynBqcl/B4bDYWaTSWr1ek6Hw5Sp1ahFkSaTCfGCd0gFUylOBgKkMhmq9HrqDIbCuYgkccLnIyZJ2DQaWktKEGWXOoO7fT7I5ajU6znl96MVRRaZzajkchKZDMf9/vx1CQIj0SjrHA4AxiIRBsNhVDIZjSUl2NTqQp3jsRhDoRC5XA6XXk/NBbpS2SzHfT6CqRRWtZo2kwlRJiOUStEfClGqUiHK5QyGQmjlcrpsNlKZDN3zYb6EJPHa7CydZjNKufwP7SLvWjzeEMdPe0hLGWrKLbTVOQpjcCKZ5tCps0RiKdrrHYUy0XiKI73jaFQKlrVVATA0PsfETBB3hYXKMhOJZJrDveP4gjEcNgNdLZXIZTJC0QTH+j2UWQ2kpQyT3hDvW9FIWspw+NRZvP4oRr2aruZKDLr8eJ/LwfHTHkYn/ahVIu11Tpyl+bHt5OAUvmCMZredYwMe6qtsaFQK2uudyC/on15/hKN942QyOeqrbTRUn3/nji8Y42jfOLFEmgp7CYubKpDJLt0ks+fIEBq1AofVQM/pSdYvr0ejUuDxBukZmEQQoNldhstpBmBg1Mv0XJj2
egf9IzN4/VGaauzUV9noOe1hYiZYaO8ciZTEoZNj+ENx7BY9XS2VKER54f7WlFuocuTHmX3HhslkcqzocKEQ5UzOhjjW70EQoKnGTk25pVDvbCBKd98EcrmAXnv5efRPxbvegJkaG2LP1sfQl5gvykUJzM2wdO1GouEgT/3k+4T9c4Vzh196nls/8ufoDBevGGOREP/1o38mEvSj0mhJxmMc3v0Cd37yi2j1Rrpf2cHBHdsKnxcVysL/02dHOLZvF1I6zaIV1zLWf5I9Wx+jacnV1LUt5pVtj3Pq0L7C5w8qldx03ydxVLs5/NJzBGZn0BlKiIaDlJZXMXiym4M7tuWNJ0EguX0rS9duZOm6jQyf6mH7b34OuRxKlZpkIo67pYPrP/BhfNMe9mx9rFCX1VGBo6qWPVsfw17hoqZ5EclEnKd+8i/4Z6YKerpf2cHtH/08MlHO0z/5ATMTo+iMJhLRCAd3bGPThz5NaUX1W/rs3mv0BYP88swZcoBaFIlLEi0mE/fV1ZHN5fjxwABj0Xye0q6pKTLZLDJBQCGT8crUFHPJJDU6HSaVil0eD6cCAe6vqyOeyfDD/n5m4vFCWy9PT/OppiYiksSTo6PoRJF4JoMok7G6rIwfDwwwGolgVCiIShLPjY/z0aYmqnQ6fjsywpG5OZRyOZlslhc8Hu5yuVhstdLt83F4dhajUkkoleI6pxOZILDd4+H68nKcWi3D4TC/OHOGVDZb0LOitJRN1dXMJBL8e28v6VwOo1JJIJnEqdHw8aamSyb9Z8fHiafTiHI5qUwGgANeL59ubiZ64XVJEjKZjHUOB9snJnhp6ny/fsHj4d66OhqNRvZNT/PM+DgqmQy5TMaLHg9rHA42VlQQkSR+2N/PXCJRKLtfp+PjTU3MpVI8OTqKUaEgIklk5xMiRiIR1peX89zEBEBBU1NJyXvOgNl3bJifP/0agiCgUojEk2lWL6lly6ZlpKUMf/+THXi8IQC27TlZKKdRKXjypeNE4yna6h1oVAqe2NnDwKiXrz64gUA4zj/+dCezgfP5e7WVVh5+YD0zvgiPbjtEiV5NKJrEUqJlTVct//DTnYxPBzAbNYQiSZ7Y0cNXHtyA02bgX3+9l54BD2qVgnQ6w2939PCZD6xmUYOT3YcGOTYwgcmgIRCOc89NXXj9EbYfGOD+W5bhsBk4NjDBfzy+Hylzvm9vXtvGpmvbGDw7y/ce3U02m8OoV+EPxWl0lfL5+9deZAAB/Oczh1Eq5AjkDY2VHTUc7RvnF1sPkc3m+5cgwD03LWXt0jr2Hxth37Hhgrb8fTxFbaWVwbPn8+Q+eGMX65fXMxuI8k8/24k/dH5M2LbnFF9+4DrCsSSPbjvE0tYqPnHXSmYDUX721Gs4bAbWdNVyoGeUnz198CIdd29cwvrlDXi8Qf7xp7uIJVIAKBUL28+vmBwYg8nMli99kxvv/RgAI30nADi06znC/jk6V63nwa99i65rbyAc8HFw+9ZL6pgeG6HEWsqqG+/gga/8LYtWXEssEmLwRDfpZJKje7Yjl4vc9tHPcc9Df4nOeN4Aqp33sowP9V/0t759CdNnRzh1aB9WRwVbHv5rbvvI55AkiVe3P31R+zZnJevvuA9TaRknX9uLXC5y3xf/inse+kvsFS4mR/NxxsmxQcpr6rnzE1/k/oe/icFsZbi3h2g4WKgrlUxw9fs2cdWGWy65zu5XduCfmaLtqtU8+BffYum6jfi905x87RVmJ8eZmRjF3drBfV/4OhvevwWTzc740MAf/GyK5BkOh3EbjXyqpYWvLlqEVaWiNxAgkEzSGwgwFo1SodPxtc5O7q2tvajsIvP8SiwUIgsMhcNo5HLqjUZ2T00xE4+zqqyMry9ezHqnk5l4nIMXJADHMxnWlJVxp8vFdDzOaCRCvdHIVzo6uLeujjKtltOhEHFJIphOs8hi4X90dPCJ5mZyuRz7Zy5O0tOKIre7XLTN67qQp8bGSGezfKih
gf+5eDEuvZ5XZ2eZSSQ4OjtLMpvlTpeLh9vbWVVWhkwmY+oCw+FCssAH3G7+oqODGr2eqViMYz5f4XxMkljtcHCXy8VMIsHu6WmsajVfXrSIP2ttRQCePXsWgIOzswjAVzo7ebi9nWq9ntFwmBywY2KCuUSCdQ4HX1+8mFV2O+PRKEfmzi985ILAF9rb+VxrK3JBoDcYxKpU8rXOTgCsajV/u3TpRZ6s9wpnxmZpdpfxzU/fxD984VYMOhV7u4dIJNMc6BnF4w3R7C7jO1+6jS2blhfKyWQCS5oryWZz9A5Nk0hJDJ6dw1qiw11h4cldJ5gNRLl1XTv/56t3sm5ZPUPjcxzoOZ8AHI2n2Ly2jQ9uXMLg+Czj0wGuaq/m2w9t5oFbl2O3GOgdyntX0lKGVZ1uvvvwbXz2g6vJZnPseu3i/BZLiZYHbl1Oi/vilwFmszke3XoYgC9/+Dr+6Yu3UWEv4Zk9vQTCcfZ2D5OWMnzm7mv41uc2sarTTTKVYXoufNl7lkxJLGur5iO3X41MJvCr546iUoh8/ZMb+fZDmzEbtfx2Rw/JlFQos7S1iu99+XbWLq0jl8vljcPPb+Yjt18NwJHefF9/YkcP/lCcTde28c9fuYPrVzQxPRfmyZeO01bnQKtW0js0TTab48SZfO7astZqkimJXz57GKUo53994ga+/dAmLEYdT+w8TiKZ5rm9fcQSKTauaua7D9/Omq66P7TLvCW86z0w5ygtr0ajN+Cszt/QcztuPCNnAGhfcS0KlYrOa67j6J4XmRi6NCnL3dqBxVHO0Mmj7Nn6OL4Zz3xdEQJzM6RTSZyuukIox+aoIDjnnW+/ihKLjcmRQaR0ivHBfjR6A+U19Rx5+UUAlCoVPftfAkBUKJidnCB3QXr76lvuQmc0Feob6TvOkz/+F2pbO1h9y13YnJUArNx4O+OD/Zw5cZRUIk5mPnkwFg4V6qpubKXzmusALvK0AHiG89cupVIceflFEtEIAF7PWZqWXI1KrWG07yQv/L+fUFXfzI33fAyN3kCRP44bKyo4Ewpx0u/nkCSRnvdQRCSJyXnvSafZjF4UaSwpQS6TFfpHh8XCS1NTnAmHsavVJDIZllqtiDIZg6H8c09mMuyanCQi5Qe8iWiUpvm8FKdWy/UVefdyPJNBI5czHA7zn4ODNBiN3F9Xh2F+4r2vro5jc3M8PzFB8pzG/5agurasjEXz4azeQKBwPJhO400kUMvlDIdCDM9ry+VyTMZiVOh0QN67MhqJ0FhSwg3l5ZcNIZ2j3mhEFASW2WyMRCJMxmK49PlwqEOr5Yb569o/M0Mul0Mtk3Fg3uBSymTMJpOkslnKNRpmEwl+2N9Pq8nEzZWVVM7rOTOvMypJ+Xs4f72eWKwQgirTajEr815XrSgSTqcLz/C9zpZNy+k57WHPkUHiyTTZbI5cDsKxJGNTec/4qs4a9FoVnY0Xh9+Xt1Xz8uFBTg1NIZMJZLJZlrVVIQgCvcPTAIQiCbbtOUUkngRgbNJPmTU/JjXV2LllTSsAc4EoKqXIkd5xMtn9tNU5+Ny9awqhjk/ctZKDJ8Z47IVuwtF8XaHoxcbzbevaaXZf+ibjiZkgoWgCo05Nz0B+bpDLZWSyWSZmgrjKLeztHuYXWw/R1VLJ0tZKtmxadtkQEuS9F1s2LQPy4atkSqLUrC/szlKIchLJGDO+SKFMfZUNrVpJTYWF3YcHaaguxWTQUFdlAyhcU/9ovv9fv7IJtVJk46pmXjzQT/+IF1EuY3FTBfuODTM8McepwenCcxganyOZkrCZdBw8MVbQmQxKTPsijE3mn+X65Q3oNEpa3GXseHXhFrdXjAHzusxPAEplfhAS538LJyNdumNgcmSQbb94BK3BSP2iLtTaC3NG8pOCUqW+pNw5atuWcHTPixw/sJt4NELb8msQZDKk+bZCfh/S/MBostpfVwfA+jvu5cSrLzN08hiv7XyW13Y+S8fKday44VYOvPAUxw/s
psLdgMNVi/wyK77f9ZtW5/TMeMYQ5/NeSsurUGv1aHR6bv/Y5+neu5OxgZOM9B1n//NPctN9H8fpWlhr+93O8xMTvDI9jVuvx200opDLYb4/SPP9VPU6v9Vk12go02gYDAaxqvKDcYfVmi87P4mOR6MFQ6BCp0NzYV0XGMoauZxPNTezZ3qavmCQ3kCAZ8bH2VJfT7VOxyO9vcwmEnTZbJS8SW/COS1SNstQ5PzAW6HTIQDtZjP319dz0OvlqM/Ha7Oz2DUaPt7YeLHey6CYD8tcZDRccF3n2g6l04W2TSoVJvL5LXe63Th1Onrm5tjp8bDT4+Fqu53NVVVk5usYj8WQzX93KnQ6VK8TCir+ZtzF/Pzp19h3bJj2eifuCgui/LxBKkn556JWXb4vnZuET5yZLNzX5W35cHVayj+ZwfFZZEK+TpfTgkZ9+bqsJh1ffXADL+zv5/hpD4dPnUWtFPnClnWUmvX87x++SCiaYFVnDWaj9k1d4zktKSlT2DYsIOByWsjlclzbVYtGpWBv9xC7Dw2y8+BpasotPPzAehTipf3owj50ru5YIlWoW61U4HJayPwBRnIul69fOd/uuVBPen4eW9ZWxb5jw/Sc9tA3Mk2Vw4TDZmDGF57XkS7oUCrEvI5MtqBzoUNH57jiDRiTzU444GPq7BBV9S1MjQ2Ty+Uw2x2XfHag5xDZbIarNtxC/aIuju3dyWh/PhRlMFsQBIHZyXEyGQm5XLzkvSr17XkDpnvvTiBv0ABYSvNt2ZyV3HD3gwiCgN87jVZvuCiP5hzJeIzuvTspsdi461MPE/TN8l8//B4nXt3D1ddvZqD7IIJMxo33fgy5qGDsdO9FOT5vhKXUgX9mivar1tCydCUA02eHKaty4xk5w9kzfbQuW8m1mz9A35H97Nmaz98pGjB/HEfm5hAEgS0NDShlMk4Hg5x7aueMkrFwmKVWK3FJusg7B/kw0naPh4NeLzpRxD3vhbBrNHgTCVbZ7Sy15VdiY5EI1Xo93suEZkYiEfoDAZbabNzmcnFkdpYnRkc56PUiA2YSCRpLSrjd5SKUSrHrTWyPNqtUKOY9R/fU1mJSKpGyWaYTCSq0Wl71egkkk3zQ7QZB4PGREU75/ZwJhwthsv9OIpNBL4qMzRslFtXlEwfLtNrCvfxIYyMyQSCQSiEDREFgx8QEFrWaz7a24k8meaS/n4NeLzdXVlKqUhFKpVjndNI2n0A9Fo1SrdMxfIEh9rv4QyaaK4FMNsurx0dRKUX+7INrgBwHT4wRjOT7Xqkl308Hz87S2Vhe8KKcQxDyoZEdrw5w8MQYZVZDIbm0vNTIwKiXzWvb6GysKNRTV2XjzNlL35F0cnCK/pEZrl/RyAObl7P91QF+s/0Ye7uHaXGXMReMsqKjhntvXsr4dIAXD/T/3tfpsBmQyQTkMqHg1UmmJGZ8EaocJp7f10c8meahe9eSSKV55LF99I/MMHh2jmb3797FVD6fSKxRKfjSh9ahUopE4ylCkQTOUiO7eXPblM/dt8HxWRqqSzk95p0/nvfItrjL0GtVvHRoMB/Kas0bjOX2/HmN+ryOWCJFMJLAaTNSatYzG4gyPDFHe72TaDz1pnS91VzxBsyiFWsZHxpg+2M/w+mqZ2psCEEQ6Lr2hks+W2LJD/6v7XqW0YGThTwaKZ1CZyihurGN0f4T/Pbfv4NSpWF6fOSi8ma7A7M9bxzojCYc1flty7Xtizm2b1ehrEZvYGpsGFdTGxvu2nKJDqVKzdjASQJzXuamPWQzWVLJBPYKF4IgYLSW4p0Y47lf/ohcLod3Iu/qy2Yyl9R1OZZcez0jfSd45ZnfMNzbQzIex+sZy+ffWEs5vn83Qye7aehYxuRY/otjLybw/tFYVSrORqM8eiYf1jw7n7CbzmZpN5t5YWKCIz4f/lQKXypFJpdDfsEq7ZwBk8xkWGy3FzwF651O+oNBnhwb47jfT0KSGI/F+IDbjVN76SpTJZOx3+ul
x+djidXK2VgMgEqdDpNSiVwQGAyF+PXwMBPz59K/55u8ZMCG8nKeGx/nB6dO4TIYmIrFiEoSDy9aRDSdZs/0NBOxGFU6HZ55r5FDo3ndOh/p7cWmVjMYCiEKAostlsvqaTAacen1jEQifL+3F5NSyWg4TKVOx4cbGzkTDjM9M8PUfLgunk5TptEgFwQ2lJczPDDAr4eGqDMaCadSTMbjfKihAcXvCG9BPpykkMkIpFL8eGCAO1wuzK9jZF2JyGUyrCYtM74I3//VHuLJVCHvIy1lWNnh4pk9p3hhfx9jk3483uAldSxvyxswiWSa913dWDi+eW073/3FLv7j8f00u8sIhOOMTwf4/H1rUVzGCyDKZbywv4+jfeNc1e7i1FA+fF5TbqHUrEcQoLt/gv/7xIHCpH7Oq/BGaNVKrl/RxPP7+vjmI89TU2FhbNKPlMnyd39+C75gjJcOnWFiJkh5qRGPN4hKKeKwvXH43W4xsLKzhv3HRvjGI89RYS9hxONDrVTw15+56ffSdyE3r2nl9NjL/OBXr1BfbWNg1Isg5I9DPveoq6WSlw8PIghCYQeYzaTjmsVu9nYP81f/9ixVDjPDE3OoFCJ/89mbuXZpHb3D0/zotweorbRe1oj8UyK/41Nf+YZd9+61Y+LRMPFoBGe1m9LyKiDHzMQYJpud2tZOjBYblbWNRENBwoE5LGXlrL7l/VTWNl5Sl73SRSaTIRzwQS5v/GQzGQwmCxXuBlwNrSRiUaLhIGa7g8bOZfntmI2thdwVURSR0mkaO5dTVlUDgEwmo25RF+RyRMJBkrEorsY2lm+4GaVKzezkOGqtjrr2JYgKJYIg4G7pIJ1MMDk6SCwcwt3Swepb3o9CqaLS3UA0FCAc8GEudeBqakcml2NzVKDS6ogE/dgrqnG68omgUjqF3zuF1VFOdUMLGp0Bd/MiEvEoIf8sMrmcRVevoXnpCvQlZhyuWoI+L1Ojg8hFBZ2r1tN+9Zp3tdtcOHUcoyyDsIC7Q+oMBkLpNIFUCptaTavJhEwmw6nR4NRqaTAa8SeTRCWJLqsVk1KJVhRZPB8q0ooiwXQarShyjd2OcT4fQ69Q0GIyEZMkfMkkcpmMlfPemAwwHY/j0GoL+TAGhYIagwFfMslINIpcEFhlt3NNWRkaUcSp1eJPpQilUnRZrWhEEbVcTnNJCcF0mhzQWFJSmKSD6TTJbJZag4EyjYZqvR6bWk0kncaXTGLTaLipqgqnRkONwYBGFPHEYkzGYti1WjZXVRVyYy7klelp0tksqx0OxqNR7BoNd7nd2DUapFyOqXicMo2G5nmPiUDeyBOAkCQRSaepMxq5qaoKjVxOm9lMKptlOBwmmErRajZzm8uFSi6nRKmksaSE6Pw9VMrlrLbb6bBYSGWzzCWTVGg01BrzK+XxaBSDUsliiwW5IGBVq/Elk0i5HIvM5jcMh71t5DLISxWoLaY3/uxbSEttGb5QjFl/lJoKK83uMpQKOTXlFirsJhpr7MwFo0TjKTZc3YhaJWI16Vjakp84TQYNHm8Ik1HDjauaCzkrVpOOtjoH4ViKGV8YtUrBjde0FBJOZ3xhXOWWQs6KzaSjptzCjC/CwKgXrVrJzatbWb2klhK9GrvFgNcXIRxLsn55PYIgoFUrWdpaxfRcGFGU0dlYQYk+nyowF4giZbK01TmwWww0u+2YjVrCsQRef4Ryewl3b1xCmdVAS20ZSlHOiMfPxEwAl9PClk3LcNgufQXFmbFZ7FYDV7WfXxh2NJSjVSsJRxP4gjFcTjN337gEs1FbyINpr3diLdERjiUJRRPUVdmoKbcgZbKMTwcoLy2hs6mCUrOeaoeZUDTJXCBKlcPE3Ru7aK457wnSa1T4QjGaauys7KgpHF/UUI5OoyQcTTIXjOFyWrh7Y16Hs9SIyaDBF8wvbDZd20YiJeGusFy0nfxPQSIYRvjZ0dlcm/318zqKFLkSkD3+SyrkKYT34A6Rdyvf6ukhmk7z
ja6ui94xU+T1yWWSKFu1mOpqFlpKkSJvK4HRiStnG3WRIkWKFClS5L2DWFzXFClS5J3IitJSUpkMxTGqSJEilyDIENWiQCqTQykvDhNFihR557De6VxoCUWKFHkHkk1LIGqQva/OSP9sglSm+HvhRYoUKVKkSJF3Ltm0RHB8EkPbBkSzRuS2ZhPbh4KkMjmEosO2yBVIFSKSfxphoXaHFCnyp0DIkgplCZ6dXmglRYq85eTIIYgajF23I9eZ+P+i+I+D7iU6LQAAAABJRU5ErkJggg==\"\"\"\n",
        "# Decode the base64-encoded PNG above into a PIL image.\n",
        "slider = PIL.Image.open(io.BytesIO(base64.standard_b64decode(slider_b64)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "co32HAdaVZWt"
      },
      "outputs": [],
      "source": [
        "# Dataset split name (\"train\" / \"test\") -\u003e corresponding arrays.\n",
        "ds_x_map = dict(train=x_train, test=x_test)\n",
        "ds_y_map = dict(train=y_train, test=y_test)\n",
        "ds_y_pic_map = dict(train=y_train_pic, test=y_test_pic)\n",
        "\n",
        "\n",
        "def make_run_videos(ca, num_steps, eval_bs, prefix, ds=\"test\", disable_black=False,\n",
        "                    color_missclassifications=False):\n",
        "  \"\"\"Renders `num_steps` CA steps on a random batch and writes `prefix`.mp4.\n",
        "\n",
        "  Args:\n",
        "    ca: CA model exposing `initialize`, `classify` and `__call__` (one step).\n",
        "    num_steps: number of CA steps rendered after the initial frame.\n",
        "    eval_bs: number of digits sampled (with replacement) from split `ds`.\n",
        "    prefix: output path without the \".mp4\" extension.\n",
        "    ds: dataset split key, \"train\" or \"test\".\n",
        "    disable_black: forwarded to `classify_and_color` for every frame after\n",
        "      the initial one.\n",
        "    color_missclassifications: if True, the background of each digit whose\n",
        "      alive cells mostly (\u003e= half) disagree with the true label is painted\n",
        "      with an error colour.\n",
        "  \"\"\"\n",
        "  ds_x = ds_x_map[ds]\n",
        "  # randint's upper bound is exclusive, so shape[0] makes every sample eligible.\n",
        "  new_idx = np.random.randint(0, ds_x.shape[0], size=eval_bs)\n",
        "  x = ca.initialize(ds_x[new_idx])\n",
        "\n",
        "  if color_missclassifications:\n",
        "    yt_pic = ds_y_pic_map[ds][new_idx]\n",
        "    yt_label = tf.argmax(yt_pic, axis=-1)\n",
        "    # live_m is computed once from the initial state and reused for all frames.\n",
        "    live_m = tf.cast(x[:,:,:,0:1] \u003e 0.1, tf.float32)\n",
        "    dead_m = 1 - live_m\n",
        "    total_alive = tf.reduce_sum(live_m[...,0], axis=[1,2])\n",
        "    # Background colours for wrongly / correctly classified digits.\n",
        "    error_color = tf.constant([1.0, 0.8, 0.79])\n",
        "    correct_color = tf.constant([1.0, 1.0, 1.0])\n",
        "\n",
        "  with VideoWriter(prefix + \".mp4\") as vid:\n",
        "    for i in tqdm.trange(-1, num_steps):\n",
        "      if i == -1:\n",
        "        # Initial frame, before any CA step.\n",
        "        image = classify_and_color(ca, x, disable_black=False)\n",
        "      else:\n",
        "        x = ca(x)\n",
        "        image = classify_and_color(ca, x, disable_black=disable_black)\n",
        "\n",
        "        if color_missclassifications:\n",
        "          y_label = tf.argmax(ca.classify(x), axis=-1)\n",
        "          correct = tf.cast(tf.equal(y_label, yt_label), tf.float32) * live_m[...,0]\n",
        "          total_correct = tf.reduce_sum(correct, axis=[1,2])\n",
        "          # [BS] -\u003e [BS, 1, 1, 1] so it broadcasts over any image size.\n",
        "          batch_wrong = tf.reshape(\n",
        "              tf.cast((total_correct / total_alive) \u003c 0.5, tf.float32),\n",
        "              [-1, 1, 1, 1])\n",
        "          backgrounds = dead_m * (batch_wrong * error_color +\n",
        "                                  (1 - batch_wrong) * correct_color)\n",
        "          image = image * live_m + backgrounds\n",
        "\n",
        "      image = zoom(tile2d(image), scale=2)\n",
        "\n",
        "      im = np.uint8(image*255)\n",
        "      vid.add(im)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P35Un4ZxYFIq"
      },
      "outputs": [],
      "source": [
        "# @title visualize original runs\n",
        "eval_bs = 100\n",
        "num_steps = 200\n",
        "\n",
        "# Render runs of the model returned by get_model(), colouring the background\n",
        "# of misclassified digits.\n",
        "ca = get_model()\n",
        "make_run_videos(ca, num_steps=num_steps, eval_bs=eval_bs, prefix=\"l2_runs\",\n",
        "                color_missclassifications=True)\n",
        "mvp.ipython_display('l2_runs.mp4')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jA-gfTfiqsuC"
      },
      "outputs": [],
      "source": [
        "ds_x_map = {\"train\": x_train, \"test\": x_test}\n",
        "ds_y_map = {\"train\": y_train, \"test\": y_test}\n",
        "ds_y_pic_map = {\"train\":y_train_pic, \"test\":y_test_pic}\n",
        "\n",
        "def _sample_adversary_masks(live_m, percentage_virus, check_at_least_one):\n",
        "  \"\"\"Samples one binary adversary mask per batch element.\n",
        "\n",
        "  Each cell independently becomes adversarial with probability\n",
        "  `percentage_virus`. With `check_at_least_one`, a sample's mask is redrawn\n",
        "  until at least one adversarial cell lands on an alive cell of that digit.\n",
        "\n",
        "  Args:\n",
        "    live_m: [BS, H, W, 1] float mask of alive cells (1. alive, 0. dead).\n",
        "    percentage_virus: per-cell probability of being adversarial.\n",
        "    check_at_least_one: whether to enforce one adversarial alive cell.\n",
        "\n",
        "  Returns:\n",
        "    A float32 numpy array of shape [BS, H, W, 1].\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if `check_at_least_one` is set but `percentage_virus` \u003c= 0,\n",
        "      since the redraw loop could then never terminate.\n",
        "  \"\"\"\n",
        "  bs, h, w = live_m.shape[0], live_m.shape[1], live_m.shape[2]\n",
        "  if not check_at_least_one:\n",
        "    return (np.random.uniform(size=[bs, h, w, 1]) \u003c percentage_virus).astype(\n",
        "        np.float32)\n",
        "  if percentage_virus \u003c= 0:\n",
        "    raise ValueError(\"check_at_least_one requires percentage_virus \u003e 0.\")\n",
        "  masks = []\n",
        "  for b in range(bs):\n",
        "    mi = (np.random.uniform(size=[h, w, 1]) \u003c percentage_virus).astype(np.float32)\n",
        "    # A digit with no alive cells could never satisfy the check: keep its mask.\n",
        "    if tf.reduce_sum(live_m[b]) \u003e 0.5:\n",
        "      while tf.reduce_sum(live_m[b] * mi) \u003c 0.5:\n",
        "        mi = (np.random.uniform(size=[h, w, 1]) \u003c percentage_virus).astype(np.float32)\n",
        "    masks.append(mi)\n",
        "  return np.stack(masks)\n",
        "\n",
        "def _misclassification_backgrounds(orig_ca, x, live_m, yt_label, total_alive):\n",
        "  \"\"\"Returns background colors that flag misclassified digits.\n",
        "\n",
        "  A digit counts as misclassified when fewer than half of its alive cells\n",
        "  predict the true label under `orig_ca.classify`. The dead cells of such a\n",
        "  digit get a light red tint, dead cells of other digits are white, and\n",
        "  alive cells are 0.\n",
        "\n",
        "  Args:\n",
        "    orig_ca: CA whose `classify` output is argmax-ed per cell.\n",
        "    x: current CA state.\n",
        "    live_m: [BS, H, W, 1] float alive mask.\n",
        "    yt_label: true labels, broadcastable to [BS, H, W] (assumed; TODO confirm).\n",
        "    total_alive: [BS] number of alive cells per digit.\n",
        "\n",
        "  Returns:\n",
        "    A [BS, H, W, 3] tensor to add to the alive-masked image.\n",
        "  \"\"\"\n",
        "  y_label = tf.argmax(orig_ca.classify(x), axis=-1)\n",
        "  correct = tf.cast(tf.equal(y_label, yt_label), tf.float32) * live_m[..., 0]\n",
        "  total_correct = tf.reduce_sum(correct, axis=[1, 2])\n",
        "  batch_wrong = tf.cast((total_correct / total_alive) \u003c 0.5, tf.float32)\n",
        "  # [BS] -\u003e [BS, 1, 1, 1], so it broadcasts over height, width and channels.\n",
        "  batch_wrong = tf.reshape(batch_wrong, [-1, 1, 1, 1])\n",
        "  error_color = tf.constant([1.0, 0.8, 0.79])\n",
        "  correct_color = tf.constant([1.0, 1.0, 1.0])\n",
        "  background = batch_wrong * error_color + (1. - batch_wrong) * correct_color\n",
        "  return (1. - live_m) * background\n",
        "\n",
        "def make_adv_run_videos(orig_ca, adv_ca, percentage_virus, num_steps, eval_bs, prefix,\n",
        "                        ds=\"test\", disable_black=False, check_at_least_one=False,\n",
        "                        place_adversaries_at=0,\n",
        "                        remove_adversaries_at=None,\n",
        "                        color_missclassifications=False,\n",
        "                        border_width = 56):\n",
        "  \"\"\"Writes `prefix + \".mp4\"`: `orig_ca` runs with some cells driven by `adv_ca`.\n",
        "\n",
        "  Args:\n",
        "    orig_ca: the original CA. Initializes the state, updates non-adversarial\n",
        "      cells and provides the per-cell classification used for coloring.\n",
        "    adv_ca: the adversarial CA, used to update cells selected by the mask.\n",
        "    percentage_virus: per-cell probability of being adversarial.\n",
        "    num_steps: number of CA steps (one frame each, after an initial frame).\n",
        "    eval_bs: number of dataset samples drawn (with replacement) and tiled.\n",
        "    prefix: output path without the \".mp4\" extension.\n",
        "    ds: dataset split, a key of `ds_x_map` (\"train\" or \"test\").\n",
        "    disable_black: forwarded to `classify_and_color` for all frames except\n",
        "      the initial one.\n",
        "    check_at_least_one: redraw masks until every digit has at least one\n",
        "      adversarial alive cell.\n",
        "    place_adversaries_at: step at which the adversarial mask is activated.\n",
        "    remove_adversaries_at: optional step at which it is deactivated. When\n",
        "      set, each frame also gets a timeline slider and a white border.\n",
        "    color_missclassifications: tint the background of digits for which\n",
        "      fewer than half of the alive cells predict the true label.\n",
        "    border_width: total extra width/height in pixels of the white border\n",
        "      (only used together with the timeline).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if `check_at_least_one` is set and `percentage_virus` \u003c= 0.\n",
        "  \"\"\"\n",
        "  ds_x = ds_x_map[ds]\n",
        "  # randint's upper bound is exclusive: use shape[0] so the last example can be drawn.\n",
        "  new_idx = np.random.randint(0, ds_x.shape[0], size=eval_bs)\n",
        "  x = orig_ca.initialize(ds_x[new_idx])\n",
        "  # Alive mask is taken once from the initial state and reused for every frame.\n",
        "  live_m = tf.cast(x[:,:,:,0:1] \u003e 0.1, tf.float32)\n",
        "  if color_missclassifications:\n",
        "    yt_label = tf.argmax(ds_y_pic_map[ds][new_idx], axis=-1)\n",
        "    total_alive = tf.reduce_sum(live_m[...,0], axis=[1,2])\n",
        "\n",
        "  m = _sample_adversary_masks(live_m, percentage_virus, check_at_least_one)\n",
        "  m_act = np.zeros_like(m)\n",
        "  with VideoWriter(prefix + \".mp4\") as vid:\n",
        "    for i in tqdm.trange(-1, num_steps):\n",
        "      if i == -1:\n",
        "        # Initial frame: the freshly initialized state, before any CA step.\n",
        "        image = zoom(tile2d(classify_and_color(orig_ca, x, disable_black=False)), scale=2)\n",
        "      else:\n",
        "        if i == place_adversaries_at:\n",
        "          m_act = m\n",
        "        if i == remove_adversaries_at:\n",
        "          m_act = np.zeros_like(m)\n",
        "        # Step both CAs from the same state; each cell keeps the update of\n",
        "        # the CA that currently owns it.\n",
        "        x_orig = orig_ca(x)\n",
        "        x_adv = adv_ca(x)\n",
        "        x = x_orig * (1. - m_act) + x_adv * m_act\n",
        "\n",
        "        image = classify_and_color(orig_ca, x, disable_black=disable_black)\n",
        "        image = image *(1 - m_act) * live_m + image * (1-live_m)\n",
        "        # Color the alive adversarial cells pure red.\n",
        "        image += np.array([1.0, 0.0, 0.0]) * m_act * live_m\n",
        "\n",
        "        if color_missclassifications:\n",
        "          image = image * live_m + _misclassification_backgrounds(\n",
        "              orig_ca, x, live_m, yt_label, total_alive)\n",
        "\n",
        "        image = zoom(tile2d(image), scale=2)\n",
        "\n",
        "      if remove_adversaries_at is not None:\n",
        "        # Prepend a strip holding the timeline slider and mark the current step.\n",
        "        # TODO: use the right modified slider.\n",
        "        vis_extended = np.concatenate((np.ones((slider.size[1] + 20, image.shape[1], 3)), image), axis=0)\n",
        "        im = PIL.Image.fromarray(np.uint8(vis_extended*255))\n",
        "        im.paste(slider, box=(0, 0))\n",
        "        draw = PIL.ImageDraw.Draw(im)\n",
        "        # Clamp at 0 so the initial frame (i == -1) is not drawn off-canvas.\n",
        "        p_x = ((image.shape[1]-8)/num_steps) * max(i, 0)\n",
        "        draw.rectangle([p_x, 1, p_x+5, 21], fill=\"#434343bd\")\n",
        "        bordered_im = PIL.Image.new(\"RGB\", (im.size[0] + border_width, im.size[1] + border_width), color='white')\n",
        "        bordered_im.paste(im, (border_width//2, border_width//2))\n",
        "        im = np.array(bordered_im, dtype=np.uint8)\n",
        "      else:\n",
        "        im = np.uint8(image*255)\n",
        "\n",
        "      vid.add(im)\n",
        "      # Hold the frame (30 copies in total) on state changes and at the end.\n",
        "      hold = ((i == place_adversaries_at and place_adversaries_at \u003e 0) or\n",
        "              i == num_steps - 1 or\n",
        "              (remove_adversaries_at is not None and i == remove_adversaries_at))\n",
        "      if hold:\n",
        "        for _ in range(29):\n",
        "          vid.add(im)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1_U5jQYAMtIt"
      },
      "outputs": [],
      "source": [
        "# @title visualize adversarial runs with reset (10%)\n",
        "eval_bs = 100\n",
        "num_steps = 600\n",
        "place_adversaries_at = 200  # step at which adversaries are inserted\n",
        "remove_adversaries_at = 400  # step at which they are removed again\n",
        "\n",
        "orig_ca = get_model()\n",
        "make_adv_run_videos(orig_ca, adv_ca, 0.1, num_steps, eval_bs, \"adv_runs\",\n",
        "                    place_adversaries_at=place_adversaries_at,\n",
        "                    remove_adversaries_at=remove_adversaries_at,\n",
        "                    color_missclassifications=True)\n",
        "mvp.ipython_display('adv_runs.mp4')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SA8gROioqvde"
      },
      "outputs": [],
      "source": [
        "# @title visualize adversarial runs (1% adversaries)\n",
        "# Adversaries are active from step 0, and every digit is guaranteed at least\n",
        "# one adversarial alive cell.\n",
        "eval_bs = 100\n",
        "num_steps = 200\n",
        "\n",
        "orig_ca = get_model()\n",
        "make_adv_run_videos(orig_ca, adv_ca, 0.01, num_steps, eval_bs, \"adv_runs\",\n",
        "                    check_at_least_one=True,\n",
        "                    color_missclassifications=True)\n",
        "mvp.ipython_display('adv_runs.mp4')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "machine_shape": "hm",
      "name": "Adversarial Reprogramming of MNIST CA",
      "provenance": [
        {
          "file_id": "14xFg14OUbFI6es1LMuIoDKmMh8eGDJJY",
          "timestamp": 1615994747888
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/self_organising_systems/notebooks/mnist_ca.ipynb",
          "timestamp": 1600860840150
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
